{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d1b5b77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `%pip` installs into the environment of the running kernel, whereas `!pip`\n",
    "# runs whichever `pip` executable the shell finds first.\n",
    "%pip install transformers datasets evaluate\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10e89d62",
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "# Interactive Hugging Face Hub login. Presumably only required for gated/private\n",
    "# assets or for pushing results - TODO confirm it is needed for this run.\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c0e38bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Optional dataset cache location; stays None when the CACHE_DIR environment\n",
    "# variable is not set.\n",
    "CACHE_DIR = os.environ.get(\"CACHE_DIR\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ce6603f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# Load the `train` split of the `food101` dataset, using CACHE_DIR as the cache directory.\n",
    "food = load_dataset(\n",
    "    \"food101\",\n",
    "    split=\"train\",\n",
    "    cache_dir=CACHE_DIR,\n",
    ")\n",
    "food"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2936aa74",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out 0.1% of the examples for evaluation. A fixed seed keeps the split\n",
    "# identical across kernel restarts.\n",
    "# NOTE: this rebinds `food` from the loaded split to the train/test pair, so the\n",
    "# cell cannot be re-run without first re-running the loading cell above.\n",
    "food = food.train_test_split(test_size=0.001, seed=42)\n",
    "food"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b89161d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build name -> id and id -> name lookup tables from the label feature's class\n",
    "# names. Ids are stored as strings in both mappings.\n",
    "labels = food[\"train\"].features[\"label\"].names\n",
    "label2id = {name: str(index) for index, name in enumerate(labels)}\n",
    "id2label = {str(index): name for index, name in enumerate(labels)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3699b52",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check: show the class name stored under id 79 (ids are string keys).\n",
    "id2label[\"79\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6793fde",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoImageProcessor\n",
    "\n",
    "# Checkpoint identifier; reused when the model is loaded later in the notebook.\n",
    "checkpoint = \"google/vit-base-patch16-224-in21k\"\n",
    "\n",
    "# Image processor matching the checkpoint; its `size` entry drives the crop size below.\n",
    "image_processor = AutoImageProcessor.from_pretrained(\n",
    "    checkpoint,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3318524",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "# Target (height, width) read from the image processor's configuration.\n",
    "size = (image_processor.size[\"height\"], image_processor.size[\"width\"])\n",
    "crop_height, crop_width = size\n",
    "\n",
    "# Training pipeline: random crop to the target size, rescale each pixel value x\n",
    "# to x / 127.5 - 1, then random horizontal flip, rotation and zoom.\n",
    "train_layers = [\n",
    "    layers.RandomCrop(crop_height, crop_width),\n",
    "    layers.Rescaling(scale=1.0 / 127.5, offset=-1),\n",
    "    layers.RandomFlip(\"horizontal\"),\n",
    "    layers.RandomRotation(factor=0.02),\n",
    "    layers.RandomZoom(height_factor=0.2, width_factor=0.2),\n",
    "]\n",
    "train_data_augmentation = keras.Sequential(train_layers, name=\"train_data_augmentation\")\n",
    "\n",
    "# Validation pipeline: center crop plus the same rescaling, no random layers.\n",
    "val_layers = [\n",
    "    layers.CenterCrop(crop_height, crop_width),\n",
    "    layers.Rescaling(scale=1.0 / 127.5, offset=-1),\n",
    "]\n",
    "val_data_augmentation = keras.Sequential(val_layers, name=\"val_data_augmentation\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20f22425",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "def convert_to_tf_tensor(image: Image.Image):\n",
    "    \"\"\"Convert a PIL image into a TF tensor with a leading batch axis of size 1.\n",
    "\n",
    "    Args:\n",
    "        image: PIL image to convert.\n",
    "\n",
    "    Returns:\n",
    "        The image data as a tensor with one extra dimension inserted at axis 0.\n",
    "    \"\"\"\n",
    "    np_image = np.array(image)\n",
    "    tf_image = tf.convert_to_tensor(np_image)\n",
    "    # `expand_dims()` adds a batch dimension since the TF augmentation\n",
    "    # layers operate on batched inputs.\n",
    "    return tf.expand_dims(tf_image, 0)\n",
    "\n",
    "\n",
    "def _augment_to_channels_first(images, augmentation):\n",
    "    \"\"\"Run `augmentation` on each PIL image and return channels-first tensors.\"\"\"\n",
    "    pixel_values = []\n",
    "    for image in images:\n",
    "        batched = augmentation(convert_to_tf_tensor(image.convert(\"RGB\")))\n",
    "        # Drop only the batch axis, then move channels first: (H, W, C) -> (C, H, W).\n",
    "        # A bare `tf.transpose` reverses every axis and would also swap H and W.\n",
    "        pixel_values.append(tf.transpose(tf.squeeze(batched, axis=0), perm=[2, 0, 1]))\n",
    "    return pixel_values\n",
    "\n",
    "\n",
    "def preprocess_train(example_batch):\n",
    "    \"\"\"Apply the training augmentation across a batch.\n",
    "\n",
    "    Args:\n",
    "        example_batch: Mapping with an \"image\" entry holding PIL images.\n",
    "\n",
    "    Returns:\n",
    "        The same mapping with a \"pixel_values\" entry added (one tensor per image).\n",
    "    \"\"\"\n",
    "    example_batch[\"pixel_values\"] = _augment_to_channels_first(\n",
    "        example_batch[\"image\"], train_data_augmentation\n",
    "    )\n",
    "    return example_batch\n",
    "\n",
    "\n",
    "def preprocess_val(example_batch):\n",
    "    \"\"\"Apply the validation transforms across a batch.\n",
    "\n",
    "    Args:\n",
    "        example_batch: Mapping with an \"image\" entry holding PIL images.\n",
    "\n",
    "    Returns:\n",
    "        The same mapping with a \"pixel_values\" entry added (one tensor per image).\n",
    "    \"\"\"\n",
    "    example_batch[\"pixel_values\"] = _augment_to_channels_first(\n",
    "        example_batch[\"image\"], val_data_augmentation\n",
    "    )\n",
    "    return example_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e66591d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register the matching preprocessing function on each split.\n",
    "for split_name, transform in ((\"train\", preprocess_train), (\"test\", preprocess_val)):\n",
    "    food[split_name].set_transform(transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89843967",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DefaultDataCollator\n",
    "\n",
    "# Collator configured with return_tensors=\"tf\"; passed as `collate_fn` to\n",
    "# `to_tf_dataset` further down.\n",
    "data_collator = DefaultDataCollator(return_tensors=\"tf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fa86c2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import create_optimizer\n",
    "\n",
    "batch_size = 16\n",
    "num_epochs = 2\n",
    "# The learning-rate schedule advances once per optimizer step (one batch), not\n",
    "# once per example, so count batches per epoch. Round up because the final\n",
    "# partial batch is kept when the dataset is batched below.\n",
    "steps_per_epoch = (len(food[\"train\"]) + batch_size - 1) // batch_size\n",
    "num_train_steps = steps_per_epoch * num_epochs\n",
    "learning_rate = 3e-5\n",
    "weight_decay_rate = 0.01\n",
    "\n",
    "optimizer, lr_schedule = create_optimizer(\n",
    "    init_lr=learning_rate,\n",
    "    num_train_steps=num_train_steps,\n",
    "    weight_decay_rate=weight_decay_rate,\n",
    "    num_warmup_steps=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1eb2cc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TFAutoModelForImageClassification\n",
    "\n",
    "# Load the checkpoint together with the label mappings built earlier.\n",
    "label_maps = {\"id2label\": id2label, \"label2id\": label2id}\n",
    "model = TFAutoModelForImageClassification.from_pretrained(checkpoint, **label_maps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ab33bfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arguments shared by both conversions.\n",
    "tf_dataset_kwargs = dict(\n",
    "    columns=\"pixel_values\", label_cols=\"label\", batch_size=batch_size, collate_fn=data_collator\n",
    ")\n",
    "\n",
    "# converting our train dataset to tf.data.Dataset (shuffled for training)\n",
    "tf_train_dataset = food[\"train\"].to_tf_dataset(shuffle=True, **tf_dataset_kwargs)\n",
    "\n",
    "# converting our test dataset to tf.data.Dataset; evaluation gains nothing from\n",
    "# shuffling, so keep the order fixed for reproducible runs\n",
    "tf_eval_dataset = food[\"test\"].to_tf_dataset(shuffle=False, **tf_dataset_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6594267f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.losses import SparseCategoricalCrossentropy\n",
    "\n",
    "# `from_logits=True`: the loss is given raw logits rather than probabilities.\n",
    "loss = SparseCategoricalCrossentropy(from_logits=True)\n",
    "model.compile(loss=loss, optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "291534b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train for `num_epochs` epochs, using the held-out split as validation data.\n",
    "model.fit(\n",
    "    tf_train_dataset,\n",
    "    epochs=num_epochs,\n",
    "    validation_data=tf_eval_dataset,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a570d575",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
